#include "netlib/net/socket/socket_ops.h"

#include <arpa/inet.h>
#include <fcntl.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <unistd.h>

#include <cassert>
#include <cerrno>
#include <cstdio>
#include <cstring>

#include "netlib/base/logger.h"
#include "netlib/net/call_backs.h"
#include "netlib/net/socket/inet_address.h"

// Binds socket `fd` to `addr`, logging via LOG_SYSFATAL if bind() fails.
//
// `addr` is expected to point at a sockaddr_in or sockaddr_in6; the length
// handed to bind() is chosen from addr->sa_family. sizeof(*addr) would only be
// sizeof(sockaddr) (16 bytes), which is shorter than sockaddr_in6 and makes
// bind() reject IPv6 addresses.
void netlib::net::sockets::BindOrDie(int fd, const sockaddr* addr) {
	const auto addrlen = static_cast<socklen_t>(
	    addr->sa_family == AF_INET6 ? sizeof(struct sockaddr_in6) : sizeof(struct sockaddr_in));
	int res = ::bind(fd, addr, addrlen);
	if (res < 0) {
		LOG_SYSFATAL << "sockets::bindOrDie";
	}
}

// Views an IPv4 socket address as a generic `const sockaddr*`.
// The pointer is first widened to `const void*` so that the final conversion
// is a static_cast rather than a reinterpret_cast between unrelated structs.
const sockaddr* netlib::net::sockets::SockaddrCast(const sockaddr_in* addr) {
	const void* untyped = ImplicitCast<const void*>(addr);
	return static_cast<const sockaddr*>(untyped);
}

// Views an IPv6 socket address as a generic `const sockaddr*`.
// The pointer is first widened to `const void*` so that the final conversion
// is a static_cast rather than a reinterpret_cast between unrelated structs.
const sockaddr* netlib::net::sockets::SockaddrCast(const sockaddr_in6* addr) {
	const void* untyped = ImplicitCast<const void*>(addr);
	return static_cast<const sockaddr*>(untyped);
}

// Marks socket `fd` as passive (listening), logging via LOG_SYSFATAL if
// listen() fails.
void netlib::net::sockets::ListenOrDie(int fd) {
	// Requested backlog; on Linux the kernel silently caps it at
	// /proc/sys/net/core/somaxconn (see listen(2)).
	constexpr int kBacklog = 4096;
	if (::listen(fd, kBacklog) >= 0) {
		return;
	}
	LOG_SYSFATAL << "sockets::listenOrDie";
}

// Accepts a pending connection on the listening socket `sockfd`.
//
// The peer address is written to `*addr`. The new descriptor is requested
// non-blocking and close-on-exec. Returns the new descriptor, or a negative
// value on failure. For the errors treated as expected below, errno is restored
// to the accept error so the caller can inspect it. Any other error is reported
// through LOG_FATAL.
int netlib::net::sockets::Accept(int sockfd, struct sockaddr_in6* addr) {
	auto addrlen = static_cast<socklen_t>(sizeof *addr);
	// accept()/accept4() take a mutable sockaddr*; SockaddrCast yields a const one.
	auto* peer = const_cast<sockaddr*>(SockaddrCast(addr));
#if VALGRIND || defined(NO_ACCEPT4)
	// Fallback when accept4() is unavailable: accept first, then set the flags
	// accept4() would have applied atomically. Only touch the descriptor when
	// accept() succeeded, so errno from a failed accept() survives for the
	// handling below.
	int connfd = ::accept(sockfd, peer, &addrlen);
	if (connfd >= 0) {
		int flags = ::fcntl(connfd, F_GETFL, 0);
		::fcntl(connfd, F_SETFL, flags | O_NONBLOCK);
		flags = ::fcntl(connfd, F_GETFD, 0);
		::fcntl(connfd, F_SETFD, flags | FD_CLOEXEC);
	}
#else
	int connfd = ::accept4(sockfd, peer, &addrlen, SOCK_NONBLOCK | SOCK_CLOEXEC);
#endif
	if (connfd < 0) {
		// Capture errno before logging, which may overwrite it.
		int saved_errno = errno;
		LOG_SYSERR << "Socket::accept";
		switch (saved_errno) {
		case EAGAIN:
		case ECONNABORTED:
		case EINTR:
		case EPROTO: // protocol error on the pending connection; accept(2) advises
		             // treating pending network errors like EAGAIN
		case EPERM:  // e.g. firewall rules forbid the connection
		case EMFILE: // per-process limit on open file descriptors reached
			// expected errors: hand the original errno back to the caller
			errno = saved_errno;
			break;
		case EBADF:
		case EFAULT:
		case EINVAL:
		case ENFILE:
		case ENOBUFS:
		case ENOMEM:
		case ENOTSOCK:
		case EOPNOTSUPP:
			// unexpected errors
			LOG_FATAL << "unexpected error of ::accept " << saved_errno;
			break;
		default:
			LOG_FATAL << "unknown error of ::accept " << saved_errno;
			break;
		}
	}
	return connfd;
}

// Connects socket `sockfd` to `addr`.
//
// `addr` is expected to point at a sockaddr_in or sockaddr_in6; the length
// passed to connect() is chosen from addr->sa_family. sizeof(*addr) would be
// sizeof(sockaddr) (16 bytes), which is too short for sockaddr_in6 and makes
// connect() reject IPv6 addresses.
// Returns connect()'s result unchanged: 0 on success, -1 with errno set on
// failure (errno is left for the caller to inspect).
int netlib::net::sockets::Connect(int sockfd, const struct sockaddr* addr) {
	const auto addrlen = static_cast<socklen_t>(
	    addr->sa_family == AF_INET6 ? sizeof(struct sockaddr_in6) : sizeof(struct sockaddr_in));
	// TODO(longfar): NONBLOCK
	return ::connect(sockfd, addr, addrlen);
}

// Fills `addr` as an AF_INET address from dotted-decimal text `ip` and a
// host-order `port`. Only the family, port and address fields are written.
// Parse failures are logged with LOG_SYSERR and are not reported to the caller.
// NOTE(review): inet_pton returns 0 without setting errno for malformed text,
// so the errno shown by LOG_SYSERR may be stale in that case.
void netlib::net::sockets::FromIpPort(const char* ip, uint16_t port, sockaddr_in* addr) {
	addr->sin_port = htons(port);
	addr->sin_family = AF_INET;
	const int parsed = inet_pton(AF_INET, ip, &addr->sin_addr);
	if (parsed > 0) {
		return;
	}
	LOG_SYSERR << "sockets::fromIpPort";
}

// Fills `addr` as an AF_INET6 address from textual `ip` and a host-order
// `port`. Only the family, port and address fields are written; other fields
// (flowinfo, scope id) are left as the caller set them.
// Parse failures are logged with LOG_SYSERR and are not reported to the caller.
void netlib::net::sockets::FromIpPort(const char* ip, uint16_t port, sockaddr_in6* addr) {
	addr->sin6_port = htons(port);
	addr->sin6_family = AF_INET6;
	const int parsed = inet_pton(AF_INET6, ip, &addr->sin6_addr);
	if (parsed > 0) {
		return;
	}
	LOG_SYSERR << "sockets::fromIpPort6";
}

// Returns the local address bound to socket `fd`, wrapped in an InetAddress.
// On getsockname() failure the error is logged and an all-zero address is
// wrapped instead.
netlib::net::InetAddress netlib::net::sockets::GetLocalAddress(int fd) {
	sockaddr_in6 local;
	memset(&local, 0, sizeof local);
	socklen_t len = sizeof local;
	sockaddr* out = const_cast<sockaddr*>(SockaddrCast(&local));
	const int rc = ::getsockname(fd, out, &len);
	if (rc < 0) {
		LOG_SYSERR << "sockets::getLocalAddress()";
	}
	return InetAddress(local);
}

// Formats `addr` as "ip:port" (AF_INET) or "[ip]:port" (AF_INET6) into `buf`.
//
// `size` is the capacity of `buf` in bytes, including the terminating NUL.
// For any other address family `buf` is left as an empty string. When size is
// 0 nothing is written.
void netlib::net::sockets::ToIpPort(char* buf, size_t size, const struct sockaddr* addr) {
	if (size == 0) {
		return; // not even room for the terminator
	}
	// Start from an empty string so the strlen() calls below are well-defined
	// even if ToIp() ends up writing nothing.
	buf[0] = '\0';

	if (addr->sa_family == AF_INET6) {
		if (size < 2) {
			return; // no room for '[' plus a terminator
		}
		buf[0] = '[';
		buf[1] = '\0';
		ToIp(buf + 1, size - 1, addr);
		const size_t end = ::strlen(buf);
		const struct sockaddr_in6* addr6 = SockaddrIn6Cast(addr);
		const uint16_t port = sockets::NetworkToHost16(addr6->sin6_port);
		assert(size > end);
		snprintf(buf + end, size - end, "]:%u", static_cast<unsigned>(port));
		return;
	}
	if (addr->sa_family != AF_INET) {
		return; // unsupported family: no meaningful IP/port to print
	}
	ToIp(buf, size, addr);
	const size_t end = ::strlen(buf);
	const struct sockaddr_in* addr4 = SockaddrInCast(addr);
	const uint16_t port = sockets::NetworkToHost16(addr4->sin_port);
	assert(size > end);
	snprintf(buf + end, size - end, ":%u", static_cast<unsigned>(port));
}

// Writes the numeric text form of the IP address in `addr` into `buf`.
//
// Supports AF_INET and AF_INET6. `size` must be at least INET_ADDRSTRLEN or
// INET6_ADDRSTRLEN respectively (checked by assert). For any other family, or
// if inet_ntop() fails, `buf` is left as an empty string (when size > 0), so
// callers can always treat it as NUL-terminated.
void netlib::net::sockets::ToIp(char* buf, size_t size, const struct sockaddr* addr) {
	if (size > 0) {
		buf[0] = '\0';
	}
	const char* converted = buf;
	if (addr->sa_family == AF_INET) {
		assert(size >= INET_ADDRSTRLEN);
		const struct sockaddr_in* addr4 = SockaddrInCast(addr);
		converted = ::inet_ntop(AF_INET, &addr4->sin_addr, buf, static_cast<socklen_t>(size));
	} else if (addr->sa_family == AF_INET6) {
		assert(size >= INET6_ADDRSTRLEN);
		const struct sockaddr_in6* addr6 = SockaddrIn6Cast(addr);
		converted = ::inet_ntop(AF_INET6, &addr6->sin6_addr, buf, static_cast<socklen_t>(size));
	}
	// The contents of buf are unspecified when inet_ntop() fails; reset them.
	if (converted == nullptr && size > 0) {
		buf[0] = '\0';
	}
}

// Views a generic socket address as an IPv4 `sockaddr_in`. The caller is
// responsible for checking that addr->sa_family is AF_INET; no check is done.
const sockaddr_in* netlib::net::sockets::SockaddrInCast(const struct sockaddr* addr) {
	const void* untyped = ImplicitCast<const void*>(addr);
	return static_cast<const struct sockaddr_in*>(untyped);
}

// Views a generic socket address as an IPv6 `sockaddr_in6`. The caller is
// responsible for checking that addr->sa_family is AF_INET6; no check is done.
const sockaddr_in6* netlib::net::sockets::SockaddrIn6Cast(const struct sockaddr* addr) {
	const void* untyped = ImplicitCast<const void*>(addr);
	return static_cast<const struct sockaddr_in6*>(untyped);
}

// Returns the pending error on `sockfd` (SO_ERROR), which also clears it.
// If getsockopt() itself fails, returns the errno from that failure instead.
int netlib::net::sockets::GetSocketError(int sockfd) {
	int pending_error = 0;
	socklen_t len = sizeof pending_error;
	const int rc = ::getsockopt(sockfd, SOL_SOCKET, SO_ERROR, &pending_error, &len);
	return rc < 0 ? errno : pending_error;
}

void netlib::net::sockets::ShutdownWrite(int sockfd) {
	if (::shutdown(sockfd, SHUT_WR) < 0) {
	}
}

// Closes descriptor `sockfd`, logging with LOG_SYSERR if close() fails.
void netlib::net::sockets::Close(int sockfd) {
	const int rc = ::close(sockfd);
	if (rc >= 0) {
		return;
	}
	LOG_SYSERR << "sockets::close()";
}

// Returns the local address of `sockfd` as raw sockaddr_in6 storage (which is
// large enough to hold an IPv4 sockaddr_in too). On getsockname() failure the
// error is logged and an all-zero structure is returned.
struct sockaddr_in6 netlib::net::sockets::GetLocalAddr(int sockfd) {
	struct sockaddr_in6 result;
	memset(&result, 0, sizeof result);
	socklen_t len = sizeof result;
	sockaddr* out = const_cast<sockaddr*>(SockaddrCast(&result));
	const int rc = ::getsockname(sockfd, out, &len);
	if (rc < 0) {
		LOG_SYSERR << "sockets::getLocalAddr";
	}
	return result;
}

// Returns the peer address of `sockfd` as raw sockaddr_in6 storage (which is
// large enough to hold an IPv4 sockaddr_in too). On getpeername() failure the
// error is logged and an all-zero structure is returned.
struct sockaddr_in6 netlib::net::sockets::GetPeerAddr(int sockfd) {
	struct sockaddr_in6 result;
	memset(&result, 0, sizeof result);
	socklen_t len = sizeof result;
	sockaddr* out = const_cast<sockaddr*>(SockaddrCast(&result));
	const int rc = ::getpeername(sockfd, out, &len);
	if (rc < 0) {
		LOG_SYSERR << "sockets::getPeerAddr";
	}
	return result;
}

bool netlib::net::sockets::IsSelfConnect(int sockfd) {
	struct sockaddr_in6 localaddr = GetLocalAddr(sockfd);
	struct sockaddr_in6 peeraddr = GetPeerAddr(sockfd);
	if (localaddr.sin6_family == AF_INET) {
		const struct sockaddr_in* laddr4 = reinterpret_cast<struct sockaddr_in*>(&localaddr);
		const struct sockaddr_in* raddr4 = reinterpret_cast<struct sockaddr_in*>(&peeraddr);
		return laddr4->sin_port == raddr4->sin_port &&
		       laddr4->sin_addr.s_addr == raddr4->sin_addr.s_addr;
	}
	if (localaddr.sin6_family == AF_INET6) {
		return localaddr.sin6_port == peeraddr.sin6_port &&
		       memcmp(&localaddr.sin6_addr, &peeraddr.sin6_addr, sizeof localaddr.sin6_addr) == 0;
	}
	return false;
}